{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1 align=\"center\">Registration Settings: Choices, Choices, Choices</h1>\n",
    "\n",
    "The performance of most registration algorithms depends on a large number of parameter settings. For optimal performance you will need to customize your settings, turning all the knobs to their \"optimal\" position:<br>\n",
    "<img src=\"knobs.jpg\" style=\"width:700px\"/>\n",
    "<font size=\"1\"> [This image was originally posted to Flickr and downloaded from Wikimedia Commons https://commons.wikimedia.org/wiki/File:TASCAM_M-520_knobs.jpg]</font>\n",
    "\n",
    "This notebook illustrates the use of reference data (a.k.a. \"gold\" standard) to empirically tune a registration framework for a specific use. The tuning depends on the characteristics of your images (anatomy, modality, the images' physical spacing...) and on the clinical needs.\n",
    "\n",
    "Also keep in mind that the optimal settings are not necessarily those that provide the most accurate results.\n",
    "\n",
    "The optimal settings are task-specific and should provide:\n",
    "<ul>\n",
    "<li>Sufficient accuracy in the Region Of Interest (ROI).</li>\n",
    "<li>Completion of the computation within the allotted time.</li>\n",
    "</ul>\n",
    "\n",
    "We will be using the training data from the Retrospective Image Registration Evaluation (<a href=\"http://www.insight-journal.org/rire/\">RIRE</a>) project."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import SimpleITK as sitk\n",
    "\n",
    "# Utility method that either downloads data from the network or\n",
    "# if already downloaded returns the file name for reading from disk (cached data).\n",
    "%run update_path_to_download_script\n",
    "from downloaddata import fetch_data as fdata\n",
    "\n",
    "# Always write output to a separate directory; we don't want to pollute the source directory.\n",
    "OUTPUT_DIR = 'Output'\n",
    "\n",
    "%matplotlib inline\n",
    "import registration_callbacks as rc\n",
    "import registration_utilities as ru\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Read the RIRE data and generate a larger point set as a reference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "fixed_image = sitk.ReadImage(fdata(\"training_001_ct.mha\"), sitk.sitkFloat32)\n",
    "moving_image = sitk.ReadImage(fdata(\"training_001_mr_T1.mha\"), sitk.sitkFloat32)\n",
    "fixed_fiducial_points, moving_fiducial_points = ru.load_RIRE_ground_truth(fdata(\"ct_T1.standard\"))\n",
    "\n",
    "# Estimate the reference_transform defined by the RIRE fiducials and check that the FRE makes sense (low).\n",
    "R, t = ru.absolute_orientation_m(fixed_fiducial_points, moving_fiducial_points)\n",
    "reference_transform = sitk.Euler3DTransform()\n",
    "reference_transform.SetMatrix(R.flatten())\n",
    "reference_transform.SetTranslation(t)\n",
    "reference_errors_mean, reference_errors_std, _, reference_errors_max, _ = ru.registration_errors(reference_transform, fixed_fiducial_points, moving_fiducial_points)\n",
    "print('Reference data errors (FRE) in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(reference_errors_mean, reference_errors_std, reference_errors_max))\n",
    "\n",
    "# Generate a reference dataset from the reference transformation\n",
    "# (corresponding points in the fixed and moving images).\n",
    "fixed_points = ru.generate_random_pointset(image=fixed_image, num_points=100)\n",
    "moving_points = [reference_transform.TransformPoint(p) for p in fixed_points]\n",
    "\n",
    "# Compute the TRE prior to registration.\n",
    "pre_errors_mean, pre_errors_std, pre_errors_min, pre_errors_max, _ = ru.registration_errors(sitk.Euler3DTransform(), fixed_points, moving_points, display_errors=True)\n",
    "print('Before registration, errors (TRE) in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(pre_errors_mean, pre_errors_std, pre_errors_max))"
   ]
  },
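  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reference transform above is estimated from the RIRE fiducials with ru.absolute_orientation_m, and its quality is summarized by the fiducial registration error (FRE). The following minimal sketch illustrates this kind of paired-point rigid registration. It uses the SVD-based least-squares solution (the Arun/Kabsch method) on synthetic points. This method is swapped in for illustration and may differ from the one implemented in registration_utilities. The helper names rigid_from_points and point_errors, along with all data values, are hypothetical. The sketch assumes the same convention as the code above, moving = R * fixed + t."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def rigid_from_points(fixed_pts, moving_pts):\n",
    "    # Hypothetical helper: least-squares rotation R and translation t such that\n",
    "    # moving ~ R * fixed + t, using the SVD-based (Arun/Kabsch) solution.\n",
    "    fixed_pts = np.asarray(fixed_pts, dtype=float)\n",
    "    moving_pts = np.asarray(moving_pts, dtype=float)\n",
    "    fixed_centroid = fixed_pts.mean(axis=0)\n",
    "    moving_centroid = moving_pts.mean(axis=0)\n",
    "    # Cross-covariance matrix of the centered point sets.\n",
    "    H = (fixed_pts - fixed_centroid).T.dot(moving_pts - moving_centroid)\n",
    "    U, _, Vt = np.linalg.svd(H)\n",
    "    # Flip the last axis if needed so that R is a proper rotation (det = +1), not a reflection.\n",
    "    D = np.diag([1.0, 1.0, np.sign(np.linalg.det(Vt.T.dot(U.T)))])\n",
    "    rotation = Vt.T.dot(D).dot(U.T)\n",
    "    translation = moving_centroid - rotation.dot(fixed_centroid)\n",
    "    return rotation, translation\n",
    "\n",
    "def point_errors(rotation, translation, fixed_pts, moving_pts):\n",
    "    # Hypothetical helper: distance between each mapped fixed point and its moving counterpart.\n",
    "    mapped = np.asarray(fixed_pts, dtype=float).dot(rotation.T) + translation\n",
    "    return np.linalg.norm(mapped - np.asarray(moving_pts, dtype=float), axis=1)\n",
    "\n",
    "# Synthetic fiducials: 10 degree rotation about z, a translation, and 0.5 mm localization noise.\n",
    "random_state = np.random.RandomState(0)\n",
    "angle = np.deg2rad(10.0)\n",
    "true_rotation = np.array([[np.cos(angle), -np.sin(angle), 0.0],\n",
    "                          [np.sin(angle), np.cos(angle), 0.0],\n",
    "                          [0.0, 0.0, 1.0]])\n",
    "true_translation = np.array([5.0, -3.0, 2.0])\n",
    "demo_fixed = random_state.uniform(-100.0, 100.0, size=(8, 3))\n",
    "demo_moving = demo_fixed.dot(true_rotation.T) + true_translation + random_state.normal(scale=0.5, size=(8, 3))\n",
    "\n",
    "est_rotation, est_translation = rigid_from_points(demo_fixed, demo_moving)\n",
    "fre = point_errors(est_rotation, est_translation, demo_fixed, demo_moving)\n",
    "print('Synthetic FRE in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(fre.mean(), fre.std(), fre.max()))"
   ]
  },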
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initial Alignment\n",
    "\n",
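    "We use the CenteredTransformInitializer. Should we use the GEOMETRY-based version, which aligns the geometric centers of the two images, or the MOMENTS-based one, which aligns their intensity-based centers of mass?"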
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "initial_transform = sitk.CenteredTransformInitializer(sitk.Cast(fixed_image,moving_image.GetPixelID()), \n",
    "                                                      moving_image, \n",
    "                                                      sitk.Euler3DTransform(), \n",
    "                                                      sitk.CenteredTransformInitializerFilter.GEOMETRY)\n",
    "\n",
    "initial_errors_mean, initial_errors_std, initial_errors_min, initial_errors_max, _ = ru.registration_errors(initial_transform, fixed_points, moving_points, min_err=pre_errors_min, max_err=pre_errors_max, display_errors=True)\n",
    "print('After initialization, errors (TRE) in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(initial_errors_mean, initial_errors_std, initial_errors_max))"
   ]
  },
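  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following sketch makes the GEOMETRY versus MOMENTS question concrete. Using plain NumPy, it computes both kinds of center on a synthetic volume: the geometric center of the image grid and the intensity-weighted center of mass. It is a simplified illustration, not the CenteredTransformInitializer implementation. It assumes identity direction matrices and treats the array axes in the same order as origin and spacing; the helper names and data are hypothetical. Because the moments depend on intensities, the MOMENTS choice may behave quite differently from the purely geometric one for multi-modal data (CT vs. MR) or for images with different fields of view."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def index_to_physical(index, origin, spacing):\n",
    "    # Assumes an identity direction matrix: point = origin + index * spacing.\n",
    "    return np.asarray(origin, dtype=float) + np.asarray(index, dtype=float) * np.asarray(spacing, dtype=float)\n",
    "\n",
    "def geometry_center(shape, origin, spacing):\n",
    "    # Physical point midway between the first and last voxel centers along each axis.\n",
    "    return index_to_physical((np.asarray(shape, dtype=float) - 1.0) / 2.0, origin, spacing)\n",
    "\n",
    "def moments_center(image, origin, spacing):\n",
    "    # Intensity-weighted centroid (center of mass) in physical coordinates.\n",
    "    indices = np.indices(image.shape).reshape(image.ndim, -1).T\n",
    "    weights = image.reshape(-1).astype(float)\n",
    "    centroid_index = (indices * weights[:, np.newaxis]).sum(axis=0) / weights.sum()\n",
    "    return index_to_physical(centroid_index, origin, spacing)\n",
    "\n",
    "# Two synthetic volumes on the same grid; the bright cube is shifted by 10 voxels per axis.\n",
    "demo_origin = (0.0, 0.0, 0.0)\n",
    "demo_spacing = (1.0, 1.0, 2.0)\n",
    "demo_fixed_image = np.zeros((20, 20, 20))\n",
    "demo_fixed_image[2:6, 2:6, 2:6] = 1.0\n",
    "demo_moving_image = np.zeros((20, 20, 20))\n",
    "demo_moving_image[12:16, 12:16, 12:16] = 1.0\n",
    "\n",
    "# In both modes the initial translation maps the fixed center onto the moving center.\n",
    "geometry_shift = (geometry_center(demo_moving_image.shape, demo_origin, demo_spacing) -\n",
    "                  geometry_center(demo_fixed_image.shape, demo_origin, demo_spacing))\n",
    "moments_shift = (moments_center(demo_moving_image, demo_origin, demo_spacing) -\n",
    "                 moments_center(demo_fixed_image, demo_origin, demo_spacing))\n",
    "print('GEOMETRY translation (mm):', geometry_shift)  # no shift: the grids are identical\n",
    "print('MOMENTS translation (mm):', moments_shift)    # recovers the cube shift of (10, 10, 20) mm"
   ]
  },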
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Registration\n",
    "\n",
    "Possible choices for a simple rigid multi-modality registration framework (<b>300</b> component combinations, in addition to parameter settings for each of the components):\n",
    "<ul>\n",
    "<li>Similarity metric, 2 options (Mattes MI, JointHistogram MI):\n",
    "<ul>\n",
    "  <li>Number of histogram bins.</li>\n",
    "  <li>Sampling strategy, 3 options (NONE, REGULAR, RANDOM).</li>\n",
    "  <li>Sampling percentage.</li>\n",
    "</ul>\n",
    "</li>\n",
    "<li>Interpolator, 10 options (sitkNearestNeighbor, sitkLinear, sitkGaussian, sitkBSpline,...)</li>\n",
    "<li>Optimizer, 5 options (GradientDescent, GradientDescentLineSearch, RegularStepGradientDescent...): \n",
    "<ul>\n",
    "  <li>Number of iterations.</li>\n",
    "  <li>Learning rate (step size along the parameter space traversal direction).</li>\n",
    "</ul>\n",
    "</li>\n",
    "</ul>\n",
    "\n",
    "In this example we will plot the similarity metric's value and, more importantly, the TREs for our reference data. A good choice for the former should be reflected by the latter. That is, the TREs should go down as the similarity measure value goes down (not necessarily at the same rate).\n",
    "\n",
    "Finally, we are also interested in timing our registration. IPython allows us to do this with minimal effort using the <a href=\"http://ipython.org/ipython-doc/stable/interactive/magics.html?highlight=timeit#magic-timeit\">timeit</a> cell magic (IPython has a set of predefined functions that use a command-line syntax, referred to as magic functions)."
   ]
  },
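  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The figure of 300 component combinations is simply the product of the option counts listed above (2 × 3 × 10 × 5). The short sketch below reproduces that arithmetic and shows how quickly the search space grows once the per-component parameters are also varied. Trying three candidate values for each of four numeric parameters is a hypothetical example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import functools\n",
    "import itertools\n",
    "import operator\n",
    "\n",
    "# Number of options per component, as listed above.\n",
    "component_option_counts = [\n",
    "    ('similarity metric', 2),   # Mattes MI, JointHistogram MI\n",
    "    ('sampling strategy', 3),   # NONE, REGULAR, RANDOM\n",
    "    ('interpolator', 10),\n",
    "    ('optimizer', 5),\n",
    "]\n",
    "counts = [count for _, count in component_option_counts]\n",
    "num_combinations = functools.reduce(operator.mul, counts, 1)\n",
    "print('Component combinations: {}'.format(num_combinations))\n",
    "\n",
    "# Hypothetical: try only 3 candidate values for each of 4 numeric parameters\n",
    "# (histogram bins, sampling percentage, number of iterations, learning rate).\n",
    "num_parameter_settings = 3 ** 4\n",
    "print('Including parameter settings: {}'.format(num_combinations * num_parameter_settings))\n",
    "\n",
    "# The same component grid enumerated explicitly (option indices stand in for the actual choices).\n",
    "component_grid = list(itertools.product(*[range(count) for count in counts]))\n",
    "print('Enumerated grid size: {}'.format(len(component_grid)))"
   ]
  },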
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "#%%timeit -r1 -n1\n",
    "# To time this cell, uncomment the line above.\n",
    "# The arguments to the timeit magic specify that this cell should only be run once. Running it multiple\n",
    "# times to get performance statistics is also possible, but takes time. If you want to analyze the accuracy\n",
    "# results from multiple runs, you will have to modify the code to save them instead of just printing them out.\n",
    "\n",
    "registration_method = sitk.ImageRegistrationMethod()\n",
    "registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)\n",
    "registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)\n",
    "registration_method.SetMetricSamplingPercentage(0.01)\n",
    "registration_method.SetInterpolator(sitk.sitkNearestNeighbor) #2. Replace with sitkLinear\n",
    "registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=100) #1. Increase to 1000\n",
    "registration_method.SetOptimizerScalesFromPhysicalShift()\n",
    "\n",
    "# Don't optimize in-place, we would like to run this cell multiple times\n",
    "registration_method.SetInitialTransform(initial_transform, inPlace=False)\n",
    "\n",
    "# Add callbacks which will display the similarity measure value and the reference data during the registration process\n",
    "registration_method.AddCommand(sitk.sitkStartEvent, rc.metric_and_reference_start_plot)\n",
    "registration_method.AddCommand(sitk.sitkEndEvent, rc.metric_and_reference_end_plot)\n",
    "registration_method.AddCommand(sitk.sitkIterationEvent, lambda: rc.metric_and_reference_plot_values(registration_method, fixed_points, moving_points))\n",
    "\n",
    "final_transform_single_scale = registration_method.Execute(sitk.Cast(fixed_image, sitk.sitkFloat32), \n",
    "                                                           sitk.Cast(moving_image, sitk.sitkFloat32))\n",
    "\n",
    "print('Final metric value: {0}'.format(registration_method.GetMetricValue()))\n",
    "print('Optimizer\\'s stopping condition, {0}'.format(registration_method.GetOptimizerStopConditionDescription()))\n",
    "final_errors_mean, final_errors_std, _, final_errors_max,_ = ru.registration_errors(final_transform_single_scale, fixed_points, moving_points, min_err=initial_errors_min, max_err=initial_errors_max, display_errors=True)\n",
    "print('After registration, errors in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(final_errors_mean, final_errors_std, final_errors_max))"
   ]
  },
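  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The comments in the cell above suggest increasing the number of iterations to 1000. The toy sketch below shows why the learning rate and the number of iterations interact. It runs plain fixed-step gradient descent on a hypothetical one-dimensional quadratic cost that stands in for the similarity metric. SimpleITK's gradient descent optimizer also uses parameter scales (here estimated by SetOptimizerScalesFromPhysicalShift), a convergence check, and a sampled, noisy metric, and it may estimate the step size itself. The sketch ignores all of this; the function names are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def gradient_descent(gradient, x0, learning_rate, num_iterations):\n",
    "    # Plain fixed-step gradient descent: x <- x - learning_rate * gradient(x).\n",
    "    x = float(x0)\n",
    "    for _ in range(num_iterations):\n",
    "        x = x - learning_rate * gradient(x)\n",
    "    return x\n",
    "\n",
    "# Hypothetical 1D cost f(x) = curvature * (x - x_optimal)**2, minimized at x_optimal.\n",
    "curvature = 0.5\n",
    "x_optimal = 10.0\n",
    "\n",
    "def toy_gradient(x):\n",
    "    # Derivative of the toy cost.\n",
    "    return 2.0 * curvature * (x - x_optimal)\n",
    "\n",
    "# Each step scales the distance to x_optimal by (1 - 2 * curvature * learning_rate).\n",
    "for learning_rate, num_iterations in [(0.001, 100), (0.001, 1000), (1.0, 100), (2.5, 10)]:\n",
    "    x_final = gradient_descent(toy_gradient, 0.0, learning_rate, num_iterations)\n",
    "    print('learning rate {:<6} iterations {:<5} -> x = {:.4g} (optimum {})'.format(\n",
    "        learning_rate, num_iterations, x_final, x_optimal))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With these toy values, a learning rate of 0.001 leaves x far from the optimum after 100 iterations and noticeably closer after 1000. A learning rate of 1.0 reaches the optimum in a single step, while 2.5 overshoots further on every step and diverges."
   ]
  },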
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In some cases, comparing the registration errors visually on the previous stage's color scale is not informative, as seen above [all points are gray/black]. We therefore set the color scale to the min-max error range found in the current data rather than the range from the previous stage."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "final_errors_mean, final_errors_std, _, final_errors_max,_ = ru.registration_errors(final_transform_single_scale, fixed_points, moving_points, display_errors=True)"
   ]
  },
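  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The effect described above can be reproduced with a few lines of NumPy. The sketch assumes that the error display maps each error linearly onto the color scale between a minimum and a maximum value and clips values outside that range. The helper name and the error values are hypothetical, not results from this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def normalize_for_colormap(errors, min_err=None, max_err=None):\n",
    "    # Map errors linearly to [0, 1]; values outside [min_err, max_err] are clipped.\n",
    "    # When no range is given, the current data's own min-max range is used.\n",
    "    errors = np.asarray(errors, dtype=float)\n",
    "    low = errors.min() if min_err is None else min_err\n",
    "    high = errors.max() if max_err is None else max_err\n",
    "    if high == low:\n",
    "        return np.zeros_like(errors)\n",
    "    return np.clip((errors - low) / (high - low), 0.0, 1.0)\n",
    "\n",
    "previous_stage_errors = np.array([12.0, 20.0, 35.0])  # hypothetical TREs before registration (mm)\n",
    "current_errors = np.array([0.8, 1.5, 2.9])             # hypothetical TREs after registration (mm)\n",
    "\n",
    "# Previous stage's range: every point falls below the minimum and gets the same color.\n",
    "print(normalize_for_colormap(current_errors, previous_stage_errors.min(), previous_stage_errors.max()))\n",
    "# Current data's own range: the points spread over the whole color scale.\n",
    "print(normalize_for_colormap(current_errors))"
   ]
  },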
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Now using the built-in multi-resolution framework\n",
    "\n",
    "Perform registration using essentially the same settings as above, but take advantage of the multi-resolution framework, which can provide a significant speedup with minimal effort (three lines of code). Note that this cell uses linear interpolation and a sampling percentage of 0.1.\n",
    "\n",
    "It should be noted that when using this framework the similarity metric value will not necessarily decrease between resolutions; we can only expect it to decrease within each resolution level. This is not an issue, as we are actually observing the values of a different function at each resolution.\n",
    "\n",
    "The example below shows that registration is improving even though the similarity value increases when changing resolution levels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "#%%timeit -r1 -n1\n",
    "# To time this cell, uncomment the line above.\n",
    "# The arguments to the timeit magic specify that this cell should only be run once. Running it multiple\n",
    "# times to get performance statistics is also possible, but takes time. If you want to analyze the accuracy\n",
    "# results from multiple runs, you will have to modify the code to save them instead of just printing them out.\n",
    "\n",
    "registration_method = sitk.ImageRegistrationMethod()\n",
    "registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)\n",
    "registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)\n",
    "registration_method.SetMetricSamplingPercentage(0.1)\n",
    "registration_method.SetInterpolator(sitk.sitkLinear)\n",
    "registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=100)\n",
    "registration_method.SetOptimizerScalesFromPhysicalShift()\n",
    "\n",
    "# Don't optimize in-place, we would like to run this cell multiple times\n",
    "registration_method.SetInitialTransform(initial_transform, inPlace=False)\n",
    "\n",
    "# Add callbacks which will display the similarity measure value and the reference data during the registration process\n",
    "registration_method.AddCommand(sitk.sitkStartEvent, rc.metric_and_reference_start_plot)\n",
    "registration_method.AddCommand(sitk.sitkEndEvent, rc.metric_and_reference_end_plot)\n",
    "registration_method.AddCommand(sitk.sitkIterationEvent, lambda: rc.metric_and_reference_plot_values(registration_method, fixed_points, moving_points))\n",
    "\n",
    "registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4,2,1])\n",
    "registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2,1,0])\n",
    "registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()\n",
    "\n",
    "final_transform = registration_method.Execute(sitk.Cast(fixed_image, sitk.sitkFloat32), \n",
    "                                              sitk.Cast(moving_image, sitk.sitkFloat32))\n",
    "\n",
    "print('Final metric value: {0}'.format(registration_method.GetMetricValue()))\n",
    "print('Optimizer\\'s stopping condition, {0}'.format(registration_method.GetOptimizerStopConditionDescription()))\n",
    "final_errors_mean, final_errors_std, _, final_errors_max, _ = ru.registration_errors(final_transform, fixed_points, moving_points, display_errors=True)\n",
    "\n",
    "print('After registration, errors in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(final_errors_mean, final_errors_std, final_errors_max))"
   ]
  },
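  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The three multi-resolution lines above request shrink factors of 4, 2, 1 and smoothing sigmas of 2, 1, 0 mm. The sketch below estimates the grid size, spacing, and number of randomly sampled voxels at each level for a hypothetical 512 x 512 x 30 image with 0.65 x 0.65 x 4.0 mm spacing (not the RIRE data). It assumes three things: each size is divided by the shrink factor and rounded down, the spacing is multiplied by the factor, and the sampling percentage applies to the voxels of each shrunken level. The function name is hypothetical. The much smaller number of voxels at the coarse levels accounts for a large part of the speedup."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def pyramid_schedule(size, spacing, shrink_factors, smoothing_sigmas, sampling_percentage):\n",
    "    # Hypothetical helper: per-level grid and sample counts for a multi-resolution schedule.\n",
    "    levels = []\n",
    "    for shrink, sigma in zip(shrink_factors, smoothing_sigmas):\n",
    "        level_size = [max(1, s // shrink) for s in size]   # size divided, rounded down\n",
    "        level_spacing = [sp * shrink for sp in spacing]     # coarser voxels\n",
    "        num_voxels = 1\n",
    "        for s in level_size:\n",
    "            num_voxels *= s\n",
    "        levels.append({'shrink': shrink, 'sigma_mm': sigma, 'size': level_size,\n",
    "                       'spacing': level_spacing, 'voxels': num_voxels,\n",
    "                       'samples': int(num_voxels * sampling_percentage)})\n",
    "    return levels\n",
    "\n",
    "# Hypothetical image geometry; same schedule and sampling percentage as the cell above.\n",
    "schedule = pyramid_schedule(size=[512, 512, 30], spacing=[0.65, 0.65, 4.0],\n",
    "                            shrink_factors=[4, 2, 1], smoothing_sigmas=[2, 1, 0],\n",
    "                            sampling_percentage=0.1)\n",
    "for level in schedule:\n",
    "    print('shrink {shrink}, sigma {sigma_mm} mm: size {size}, spacing {spacing}, voxels {voxels}, random samples ~{samples}'.format(**level))"
   ]
  },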
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### Sufficient accuracy <u>inside</u> the ROI \n",
    "\n",
    "Up to this point our accuracy evaluation has ignored the content of the image and is likely overly conservative. We have been looking at the registration errors throughout the whole image volume, not specifically in the smaller ROI.\n",
    "\n",
    "To see the difference, the final_transform from the multi-resolution registration above must be available. If you enabled the timeit magic in that cell, <b>comment it out again</b> and rerun the cell, because variables assigned inside a timed cell are not kept. Then run the following cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Threshold the original fixed (CT) image at 0 HU (water), resulting in a binary labeled [0,1] image.\n",
    "roi = fixed_image > 0\n",
    "\n",
    "# Our ROI consists of all voxels with a value of 1; now get the bounding box surrounding the head.\n",
    "label_shape_analysis = sitk.LabelShapeStatisticsImageFilter()\n",
    "label_shape_analysis.SetBackgroundValue(0)\n",
    "label_shape_analysis.Execute(roi)\n",
    "bounding_box = label_shape_analysis.GetBoundingBox(1)\n",
    "\n",
    "# Bounding box in physical space.\n",
    "sub_image_min = fixed_image.TransformIndexToPhysicalPoint((bounding_box[0],bounding_box[1], bounding_box[2]))\n",
    "sub_image_max = fixed_image.TransformIndexToPhysicalPoint((bounding_box[0]+bounding_box[3]-1,\n",
    "                                                           bounding_box[1]+bounding_box[4]-1, \n",
    "                                                           bounding_box[2]+bounding_box[5]-1))\n",
    "# Only look at the points inside our bounding box.\n",
    "sub_fixed_points = []\n",
    "sub_moving_points = []\n",
    "for fixed_pnt, moving_pnt in zip(fixed_points, moving_points):\n",
    "    if sub_image_min[0]<=fixed_pnt[0]<=sub_image_max[0] and \\\n",
    "       sub_image_min[1]<=fixed_pnt[1]<=sub_image_max[1] and \\\n",
    "       sub_image_min[2]<=fixed_pnt[2]<=sub_image_max[2] : \n",
    "        sub_fixed_points.append(fixed_pnt)\n",
    "        sub_moving_points.append(moving_pnt)\n",
    "\n",
    "final_errors_mean, final_errors_std, _, final_errors_max, _ = ru.registration_errors(final_transform, sub_fixed_points, sub_moving_points, display_errors=True)\n",
    "print('After registration, errors in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(final_errors_mean, final_errors_std, final_errors_max))"
   ]
  },
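  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above converts the label bounding box to physical coordinates and keeps the point pairs whose fixed point lies inside it. The following NumPy sketch performs the same steps in vectorized form on made-up data. It assumes an identity direction matrix, so that a physical point is origin + index * spacing. With a rotated or flipped direction matrix, the per-axis minimum and maximum would have to be taken over the transformed corners of the box. The helper names and values are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def bounding_box_to_physical(bounding_box, origin, spacing):\n",
    "    # bounding_box = (start_x, start_y, start_z, size_x, size_y, size_z), the layout used with GetBoundingBox above.\n",
    "    # Assumes an identity direction matrix: point = origin + index * spacing.\n",
    "    dim = len(bounding_box) // 2\n",
    "    start = np.asarray(bounding_box[:dim], dtype=float)\n",
    "    size = np.asarray(bounding_box[dim:], dtype=float)\n",
    "    origin = np.asarray(origin, dtype=float)\n",
    "    spacing = np.asarray(spacing, dtype=float)\n",
    "    box_min = origin + start * spacing\n",
    "    box_max = origin + (start + size - 1.0) * spacing\n",
    "    return box_min, box_max\n",
    "\n",
    "def points_inside_box(fixed_pts, moving_pts, box_min, box_max):\n",
    "    # Keep the corresponding pairs whose fixed point lies inside the box (bounds inclusive).\n",
    "    fixed_arr = np.asarray(fixed_pts, dtype=float)\n",
    "    moving_arr = np.asarray(moving_pts, dtype=float)\n",
    "    inside = np.all((fixed_arr >= box_min) & (fixed_arr <= box_max), axis=1)\n",
    "    return fixed_arr[inside], moving_arr[inside]\n",
    "\n",
    "demo_box_min, demo_box_max = bounding_box_to_physical((10, 10, 5, 20, 20, 10), origin=(0.0, 0.0, 0.0), spacing=(1.0, 1.0, 2.0))\n",
    "demo_fixed_pts = [(15.0, 15.0, 15.0), (5.0, 15.0, 15.0), (29.0, 29.0, 28.0)]\n",
    "demo_moving_pts = [(16.0, 15.0, 15.0), (6.0, 15.0, 15.0), (30.0, 29.0, 28.0)]\n",
    "inside_fixed, inside_moving = points_inside_box(demo_fixed_pts, demo_moving_pts, demo_box_min, demo_box_max)\n",
    "print('Box: {} to {}'.format(demo_box_min, demo_box_max))\n",
    "print('{} of {} point pairs inside the ROI'.format(len(inside_fixed), len(demo_fixed_pts)))"
   ]
  },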
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.4.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
